{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "56341c66",
   "metadata": {},
   "source": [
    "# Chapter 3: High-level Language-Specific Analysis and Transformation\n",
    "\n",
    "As we saw in the previous chapter, the IR generated from the input program has many\n",
    "opportunities for optimisation. In this chapter, we'll implement three optimisations:\n",
    "\n",
    "1. Removing redundant reshapes\n",
    "2. Reshaping constants during compilation time\n",
    "3. Eliminating operations whose results are not used\n",
    "\n",
    "Let's take a look again at our example input:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e07ae44f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "builtin.module {\n",
      "  \"toy.func\"() ({\n",
      "    %0 = \"toy.constant\"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>\n",
      "    %1 = \"toy.reshape\"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64>\n",
      "    %2 = \"toy.constant\"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>\n",
      "    %3 = \"toy.reshape\"(%2) : (tensor<6xf64>) -> tensor<6xf64>\n",
      "    %4 = \"toy.reshape\"(%3) : (tensor<6xf64>) -> tensor<2x3xf64>\n",
      "    %5 = \"toy.add\"(%1, %4) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<2x3xf64>\n",
      "    \"toy.print\"(%5) : (tensor<2x3xf64>) -> ()\n",
      "    \"toy.return\"() : () -> ()\n",
      "  }) {sym_name = \"main\", function_type = () -> ()} : () -> ()\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "from toy.compiler import parse_toy\n",
    "\n",
    "from xdsl.printer import Printer\n",
    "\n",
    "example = \"\"\"\n",
    "def main() {\n",
    "  var a<2, 3> = [[1, 2, 3], [4, 5, 6]];\n",
    "  var b<6> = [1, 2, 3, 4, 5, 6];\n",
    "  var c<2, 3> = b;\n",
    "  var d = a + c;\n",
    "  print(d);\n",
    "}\n",
    "\"\"\"\n",
    "\n",
    "module = parse_toy(example)\n",
    "Printer().print_op(module)\n",
    "print()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a3697e4a",
   "metadata": {},
   "source": [
    "## Redundant Reshapes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6777bb72",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "builtin.module {\n",
      "  \"toy.func\"() ({\n",
      "    %0 = \"toy.constant\"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>\n",
      "    %1 = \"toy.reshape\"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64>\n",
      "    %2 = \"toy.constant\"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>\n",
      "    %3 = \"toy.reshape\"(%2) : (tensor<6xf64>) -> tensor<6xf64>\n",
      "    %4 = \"toy.reshape\"(%2) : (tensor<6xf64>) -> tensor<2x3xf64>\n",
      "    %5 = \"toy.add\"(%1, %4) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<2x3xf64>\n",
      "    \"toy.print\"(%5) : (tensor<2x3xf64>) -> ()\n",
      "    \"toy.return\"() : () -> ()\n",
      "  }) {sym_name = \"main\", function_type = () -> ()} : () -> ()\n",
      "}"
     ]
    }
   ],
   "source": [
    "from toy.dialects import toy\n",
    "\n",
    "from xdsl.ir import OpResult\n",
    "from xdsl.pattern_rewriter import (\n",
    "    PatternRewriter,\n",
    "    PatternRewriteWalker,\n",
    "    RewritePattern,\n",
    "    op_type_rewrite_pattern,\n",
    ")\n",
    "\n",
    "\n",
    "class ReshapeReshapeOptPattern(RewritePattern):\n",
    "    @op_type_rewrite_pattern\n",
    "    def match_and_rewrite(self, op: toy.ReshapeOp, rewriter: PatternRewriter):\n",
    "        \"\"\"\n",
    "        Reshape(Reshape(x)) = Reshape(x)\n",
    "        \"\"\"\n",
    "        # Look at the input of the current reshape.\n",
    "        reshape_input = op.arg\n",
    "        if not isinstance(reshape_input, OpResult):\n",
    "            # Input was not produced by an operation, could be a function argument\n",
    "            return\n",
    "\n",
    "        reshape_input_op = reshape_input.op\n",
    "        if not isinstance(reshape_input_op, toy.ReshapeOp):\n",
    "            # Input defined by another transpose? If not, no match.\n",
    "            return\n",
    "\n",
    "        new_op = toy.ReshapeOp(reshape_input_op.arg, op.res.type)\n",
    "        rewriter.replace_matched_op(new_op)\n",
    "\n",
    "\n",
    "# Use `PatternRewriteWalker` to rewrite all matched operations\n",
    "PatternRewriteWalker(ReshapeReshapeOptPattern()).rewrite_module(module)\n",
    "Printer().print_op(module)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "dd49fd84",
   "metadata": {},
   "source": [
    "This looks very similar to what we had before, but is subtly different. Importantly,\n",
    "the reshape that assigns to %4 now takes %2 as input, instead of %3. %3 is now no longer\n",
    "used, and because it's an operation with no observable side-effects, we can avoid doing\n",
    "the work altogether."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c509c618",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "builtin.module {\n",
      "  \"toy.func\"() ({\n",
      "    %0 = \"toy.constant\"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>\n",
      "    %1 = \"toy.reshape\"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64>\n",
      "    %2 = \"toy.constant\"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>\n",
      "    %3 = \"toy.reshape\"(%2) : (tensor<6xf64>) -> tensor<2x3xf64>\n",
      "    %4 = \"toy.add\"(%1, %3) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<2x3xf64>\n",
      "    \"toy.print\"(%4) : (tensor<2x3xf64>) -> ()\n",
      "    \"toy.return\"() : () -> ()\n",
      "  }) {sym_name = \"main\", function_type = () -> ()} : () -> ()\n",
      "}"
     ]
    }
   ],
   "source": [
    "from xdsl.transforms.dead_code_elimination import RemoveUnusedOperations\n",
    "\n",
    "PatternRewriteWalker(RemoveUnusedOperations()).rewrite_module(module)\n",
    "\n",
    "Printer().print_op(module)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "e93ecb70",
   "metadata": {},
   "source": [
    "## Fold Constant Reshaping\n",
    "\n",
    "One more opportunity for optimisation is to reshape the constants at compile-time,\n",
    "instead of at runtime. We can do this with another custom `RewritePattern`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e1ffdc6a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "builtin.module {\n",
      "  \"toy.func\"() ({\n",
      "    %0 = \"toy.constant\"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>\n",
      "    %1 = \"toy.constant\"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>\n",
      "    %2 = \"toy.constant\"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>\n",
      "    %3 = \"toy.constant\"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>\n",
      "    %4 = \"toy.add\"(%1, %3) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<2x3xf64>\n",
      "    \"toy.print\"(%4) : (tensor<2x3xf64>) -> ()\n",
      "    \"toy.return\"() : () -> ()\n",
      "  }) {sym_name = \"main\", function_type = () -> ()} : () -> ()\n",
      "}"
     ]
    }
   ],
   "source": [
    "from xdsl.dialects.builtin import DenseIntOrFPElementsAttr\n",
    "from xdsl.utils.hints import isa\n",
    "\n",
    "\n",
    "class FoldConstantReshapeOptPattern(RewritePattern):\n",
    "    @op_type_rewrite_pattern\n",
    "    def match_and_rewrite(self, op: toy.ReshapeOp, rewriter: PatternRewriter):\n",
    "        \"\"\"\n",
    "        Reshaping a constant can be done at compile time\n",
    "        \"\"\"\n",
    "        # Look at the input of the current reshape.\n",
    "        reshape_input = op.arg\n",
    "        if not isinstance(reshape_input, OpResult):\n",
    "            # Input was not produced by an operation, could be a function argument\n",
    "            return\n",
    "\n",
    "        reshape_input_op = reshape_input.op\n",
    "        if not isinstance(reshape_input_op, toy.ConstantOp):\n",
    "            # Input defined by another transpose? If not, no match.\n",
    "            return\n",
    "\n",
    "        assert isa(op.res.type, toy.TensorTypeF64)\n",
    "\n",
    "        new_value = DenseIntOrFPElementsAttr.from_list(\n",
    "            type=op.res.type, data=reshape_input_op.value.get_values()\n",
    "        )\n",
    "        new_op = toy.ConstantOp(new_value)\n",
    "        rewriter.replace_matched_op(new_op)\n",
    "\n",
    "\n",
    "PatternRewriteWalker(FoldConstantReshapeOptPattern()).rewrite_module(module)\n",
    "Printer().print_op(module)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "af69eccd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "builtin.module {\n",
      "  \"toy.func\"() ({\n",
      "    %0 = \"toy.constant\"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>\n",
      "    %1 = \"toy.constant\"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>\n",
      "    %2 = \"toy.add\"(%0, %1) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<2x3xf64>\n",
      "    \"toy.print\"(%2) : (tensor<2x3xf64>) -> ()\n",
      "    \"toy.return\"() : () -> ()\n",
      "  }) {sym_name = \"main\", function_type = () -> ()} : () -> ()\n",
      "}"
     ]
    }
   ],
   "source": [
    "# Remove now unused original constants\n",
    "PatternRewriteWalker(RemoveUnusedOperations()).rewrite_module(module)\n",
    "Printer().print_op(module)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "5567155c",
   "metadata": {},
   "source": [
    "Now that we've done all the optimisations we could on this level of abstraction, let's\n",
    "go one level lower towards RISC-V."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
